from typing import List, Dict, Optional
from pathlib import Path
import json
import random
import time
from tqdm import tqdm
from openai import OpenAI
import logging

# Configure the root logger once at import time: INFO and above, with a
# timestamp and the level name in front of every message.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)

class Config:
    """Static settings shared by the CSQA reasoning-generation script."""

    # Number of questions drawn at random for one run.
    SAMPLE_SIZE = 300
    # Seconds slept after each handled question to throttle API traffic.
    API_RATE_LIMIT = 2
    # Default number of attempts for a single API call.
    MAX_RETRIES = 3

    # Directory holding the input file, the output file and the progress record.
    BASE_PATH = Path("/process_COT/csqa")
    OUTPUT_FILE = BASE_PATH.joinpath("reasoning_output_300_correct.txt")
    PROGRESS_FILE = BASE_PATH.joinpath("progress.json")

    # Credentials, endpoint and model handed to the OpenAI client.
    # API_KEY and BASE_URL are blank in the source.
    API_KEY = ""
    BASE_URL = ""
    MODEL_NAME = "gpt-4o"

class QuestionProcessor:
    """Generate step-by-step reasoning for multiple-choice questions.

    Each question is sent to the chat-completion API together with its
    reference answer; the cleaned-up reply is appended to
    ``Config.OUTPUT_FILE`` and the question id is recorded in
    ``Config.PROGRESS_FILE`` so an interrupted run can be resumed.
    """

    def __init__(self):
        """Create the API client and load the ids handled by earlier runs."""
        self.client = OpenAI(
            api_key=Config.API_KEY,
            base_url=Config.BASE_URL
        )
        self.processed_questions = self._load_progress()

    def _load_progress(self) -> set:
        """Return the set of question ids stored in the progress file.

        Returns an empty set when the progress file does not exist yet.
        """
        if Config.PROGRESS_FILE.exists():
            with open(Config.PROGRESS_FILE, 'r', encoding='utf-8') as f:
                return set(json.load(f))
        return set()

    def _save_progress(self, question_id: str):
        """Mark ``question_id`` as done and rewrite the progress file."""
        self.processed_questions.add(question_id)
        with open(Config.PROGRESS_FILE, 'w', encoding='utf-8') as f:
            json.dump(list(self.processed_questions), f)

    def get_completion(self, prompt: str, retries: Optional[int] = None) -> Optional[str]:
        """Ask the model for a completion, retrying with exponential backoff.

        Args:
            prompt: User message sent to the model.
            retries: Maximum number of attempts. ``None`` (the default) means
                ``Config.MAX_RETRIES``, looked up at call time so that later
                changes to the configuration are honoured.

        Returns:
            The stripped reply text, or ``None`` when every attempt failed.
        """
        if retries is None:
            retries = Config.MAX_RETRIES

        for attempt in range(retries):
            try:
                response = self.client.chat.completions.create(
                    model=Config.MODEL_NAME,
                    messages=[
                        {"role": "system", "content": "You are a helpful assistant that provides step-by-step reasoning."},
                        {"role": "user", "content": prompt}
                    ]
                )
                # A reply without text content raises here and is retried
                # like any other failed attempt.
                return response.choices[0].message.content.strip()
            except Exception as e:
                logging.warning(f"API call failed (attempt {attempt + 1}/{retries}): {str(e)}")
                if attempt < retries - 1:
                    # Back off 1s, 2s, 4s, ... before the next attempt.
                    time.sleep(2 ** attempt)
                else:
                    logging.error(f"API call final failure: {str(e)}")
                    return None
        # Reached only when ``retries`` is zero or negative.
        return None

    def format_question(self, qa_item: Dict) -> str:
        """Render a question record as the text block embedded in the prompt.

        Args:
            qa_item: Mapping with ``question``, ``answerKey`` and ``choices``;
                ``choices`` holds parallel ``label`` and ``text`` sequences.

        Returns:
            Three lines: the question, the ``|``-separated choices and the
            reference answer.
        """
        question = qa_item['question']
        choices = qa_item['choices']
        formatted_choices = [
            f"{label}: {text}"
            for label, text in zip(choices['label'], choices['text'])
        ]

        return (
            f"Question: {question}\n"
            f"Answer Choices: {' | '.join(formatted_choices)}\n"
            f"Reference Answer: {qa_item['answerKey']}"
        )

    def generate_prompt(self, formatted_q: str) -> str:
        """Wrap a formatted question in the reasoning-generation instructions."""
        # Built from explicit pieces so the trailing space on the first line
        # of the original prompt stays visible and is not lost to an editor.
        return (
            "You are a helpful assistant who provides step-by-step reasoning for questions. \n"
            "Given the reference answer, generate a clear and logical reasoning process that naturally leads to that answer.\n"
            "\n"
            f"{formatted_q}\n"
            "\n"
            "Format your response exactly as:\n"
            "Let's think step by step.\n"
            "[Your step-by-step reasoning that naturally leads to the reference answer]\n"
            "The answer is [reference answer letter]."
        )

    @staticmethod
    def _strip_list_number(line: str) -> str:
        """Drop a leading list marker such as ``3. `` or ``12. `` from ``line``.

        The marker must be ASCII digits, a dot and whitespace, so decimals
        like ``1.5 million`` are left untouched.
        """
        marker, dot, rest = line.partition('.')
        if dot and marker.isascii() and marker.isdigit() and rest[:1].isspace():
            return rest.strip()
        return line

    def format_response(self, question: str, response: str) -> str:
        """Combine the question and the model reply into one output record.

        The ``Reference Answer:`` line is removed from the question. From the
        reply, only the ``Let's think`` line, the reasoning that follows it
        and the ``The answer is`` line are kept; blank lines and list
        numbering are dropped.

        Args:
            question: Text produced by :meth:`format_question`.
            response: Raw model reply.

        Returns:
            The record as newline-joined text.
        """
        lines = [
            q_line.strip()
            for q_line in question.strip().split('\n')
            if not q_line.startswith('Reference Answer:')
        ]

        in_reasoning = False
        reasoning_lines = []

        for raw_line in response.strip().split('\n'):
            line = raw_line.strip()
            if not line:
                continue

            line = self._strip_list_number(line)

            if line.startswith("Let's think"):
                in_reasoning = True
                lines.append(line)
            elif line.startswith("The answer is"):
                in_reasoning = False
                lines.extend(reasoning_lines)
                # Clear the buffer so a repeated answer line cannot emit the
                # same reasoning twice.
                reasoning_lines = []
                lines.append(line)
            elif in_reasoning:
                reasoning_lines.append(line)

        # The reply never produced a line starting with "The answer is"
        # (e.g. "Therefore, the answer is A."): keep the collected reasoning
        # instead of silently discarding it.
        lines.extend(reasoning_lines)

        return '\n'.join(lines)

    def process_questions(self, questions: List[Dict]):
        """Generate and append reasoning for a random sample of ``questions``.

        At most ``Config.SAMPLE_SIZE`` questions are sampled. Questions whose
        id is already in the progress record are skipped. Each record is
        flushed to disk before its id is recorded.

        Args:
            questions: Question records; each needs an ``id`` plus the keys
                used by :meth:`format_question`.
        """
        Config.OUTPUT_FILE.parent.mkdir(parents=True, exist_ok=True)

        # random.sample raises ValueError when asked for more items than the
        # population holds, so cap the sample at what is available.
        sample_size = min(Config.SAMPLE_SIZE, len(questions))
        if sample_size < Config.SAMPLE_SIZE:
            logging.warning(
                f"Only {len(questions)} questions available; "
                f"sampling {sample_size} instead of {Config.SAMPLE_SIZE}"
            )
        selected_questions = random.sample(questions, sample_size)

        with open(Config.OUTPUT_FILE, 'a', encoding='utf-8') as f:
            for qa in tqdm(selected_questions, desc="Processing questions"):
                if qa['id'] in self.processed_questions:
                    logging.info(f"Skipping already processed question {qa['id']}")
                    continue

                formatted_q = self.format_question(qa)
                prompt = self.generate_prompt(formatted_q)

                if response := self.get_completion(prompt):
                    formatted_output = self.format_response(formatted_q, response)
                    f.write(f"{formatted_output}\n\n")
                    # Flush before recording progress; otherwise a crash could
                    # leave the id marked as done while its text is still in
                    # the write buffer and lost.
                    f.flush()
                    self._save_progress(qa['id'])

                time.sleep(Config.API_RATE_LIMIT)

def main():
    """Load the question file and generate reasoning for its questions.

    Reads ``csqa_300.jsonl`` (one JSON object per line) from
    ``Config.BASE_PATH`` and passes the records to
    ``QuestionProcessor.process_questions``. Any exception is logged together
    with its traceback rather than propagated.
    """
    try:
        with open(Config.BASE_PATH / "csqa_300.jsonl", 'r', encoding='utf-8') as f:
            # Skip blank lines: json.loads rejects empty input, so a stray
            # empty line would otherwise abort the whole run.
            questions = [json.loads(line) for line in f if line.strip()]

        logging.info(f"loaded {len(questions)} questions")

        processor = QuestionProcessor()
        processor.process_questions(questions)

        logging.info("Processing completed!")

    except Exception as e:
        # logging.exception logs at ERROR level and appends the traceback,
        # which the plain message alone would hide.
        logging.exception(f"Program error: {str(e)}")

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()